// SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: BSD-3

#include <cub/util_arch.cuh>
#include <cub/util_ptx.cuh>

#include <cuda/__cmath/pow2.h>

#include <c2h/catch2_test_helper.h>

// Compile-time count of logical warps, each `logical_warp_threads` lanes wide,
// that the tests in this file iterate over. When the size is a power of two
// the count is cub::detail::warp_threads / logical_warp_threads; for any other
// size the count is 1.
template <int logical_warp_threads>
struct total_warps_t
{
  static constexpr unsigned int value()
  {
    // The cast spells out the conversion to the unsigned return type.
    return static_cast<unsigned int>(
      ::cuda::is_power_of_two(logical_warp_threads) ? cub::detail::warp_threads / logical_warp_threads : 1);
  }
};

// Reports whether the bit at position `lane` is set in `member_mask`.
//
// member_mask: bit mask with one bit per lane (bit 0 is lane 0).
// lane:        zero-based lane index to query.
//
// Returns true when the lane's bit is set. A lane index that does not fit in
// the mask has no corresponding bit, so it is reported as not involved.
bool is_lane_involved(unsigned int member_mask, unsigned int lane)
{
  constexpr unsigned int mask_bits = 8 * sizeof(member_mask);

  // Shifting by the full width of the type (or more) is undefined, so reject
  // out-of-range lanes before building the probe bit.
  if (lane >= mask_bits)
  {
    return false;
  }

  // Use an unsigned literal so that probing the top lane does not shift a
  // signed `1` into the sign bit.
  return (member_mask & (1u << lane)) != 0u;
}

// Type list of logical warp sizes passed as the parameter list to C2H_TEST
// cases in this file.
// NOTE(review): presumably c2h::iota<1, 32> enumerates the integers 1 through
// 32 -- confirm the end-bound semantics against the c2h helper.
using logical_warp_threads      = c2h::iota<1, 32>;
// Explicit type list of the power-of-two logical warp sizes 1, 2, 4, 8, 16, 32.
using power_of_two_warp_threads = c2h::enum_type_list<int, 1, 2, 4, 8, 16, 32>;

// For each logical warp of the size under test, checks that no lane located
// before the warp's first lane is set in the mask returned by cub::WarpMask.
C2H_TEST("Warp mask ignores lanes before current logical warp", "[mask][warp]", power_of_two_warp_threads)
{
  // Logical warp size carried by the first parameter of the test type.
  constexpr int lanes_per_warp      = c2h::get<0, TestType>::value;
  constexpr unsigned int warp_count = total_warps_t<lanes_per_warp>::value();

  for (unsigned int warp_idx = 0; warp_idx < warp_count; ++warp_idx)
  {
    const unsigned int mask       = cub::WarpMask<lanes_per_warp>(warp_idx);
    const unsigned int first_lane = lanes_per_warp * warp_idx;

    // Lanes [0, first_lane) precede this logical warp.
    for (unsigned int lane = 0; lane < first_lane; ++lane)
    {
      REQUIRE_FALSE(is_lane_involved(mask, lane));
    }
  }
}

// For each logical warp of the size under test, checks that every lane
// belonging to the warp is set in the mask returned by cub::WarpMask.
C2H_TEST("Warp mask involves lanes of current logical warp", "[mask][warp]", logical_warp_threads)
{
  // Logical warp size carried by the first parameter of the test type.
  constexpr int lanes_per_warp      = c2h::get<0, TestType>::value;
  constexpr unsigned int warp_count = total_warps_t<lanes_per_warp>::value();
  constexpr unsigned int lane_count = static_cast<unsigned int>(lanes_per_warp);

  for (unsigned int warp_idx = 0; warp_idx < warp_count; ++warp_idx)
  {
    const unsigned int mask       = cub::WarpMask<lanes_per_warp>(warp_idx);
    const unsigned int first_lane = lanes_per_warp * warp_idx;

    // Walk the warp's own lanes as offsets from its first lane.
    for (unsigned int offset = 0; offset < lane_count; ++offset)
    {
      REQUIRE(is_lane_involved(mask, first_lane + offset));
    }
  }
}

// For each logical warp of the size under test, checks that no lane located
// past the warp's last lane (up to cub::detail::warp_threads) is set in the
// mask returned by cub::WarpMask.
C2H_TEST("Warp mask ignores lanes after current logical warp", "[mask][warp]", logical_warp_threads)
{
  // Logical warp size carried by the first parameter of the test type.
  constexpr int lanes_per_warp       = c2h::get<0, TestType>::value;
  constexpr unsigned int warp_count  = total_warps_t<lanes_per_warp>::value();
  constexpr unsigned int total_lanes = cub::detail::warp_threads;

  for (unsigned int warp_idx = 0; warp_idx < warp_count; ++warp_idx)
  {
    const unsigned int mask          = cub::WarpMask<lanes_per_warp>(warp_idx);
    const unsigned int one_past_last = lanes_per_warp * warp_idx + lanes_per_warp;

    // Lanes [one_past_last, total_lanes) follow this logical warp.
    for (unsigned int lane = one_past_last; lane < total_lanes; ++lane)
    {
      REQUIRE_FALSE(is_lane_involved(mask, lane));
    }
  }
}
